ScatterNdUpdate
=================

Writes the update values (``updates``) into the slices of the output tensor (``output``) that are selected by the indices (``indices``), overwriting the values already stored there.

.. math::

   \mathrm{output}[\mathrm{indices}[i]] = \mathrm{updates}[i]

The operator uses the coordinates in ``indices`` to locate specific sub-regions (slices) of ``output``. It then overwrites each slice with the corresponding values from ``updates``. The last dimension of ``indices`` is the index depth K. Each index row addresses the first K dimensions of ``output``, so it selects a slice that spans the remaining ``output_ndim - K`` dimensions.

Inputs:

- **output** - Address of the output tensor to be updated (input/output).
- **output_shape** - Address of the shape array of the output tensor.
- **output_ndim** - Number of dimensions of the output tensor.
- **indices** - Address of the index tensor data. Its last dimension is the index depth.
- **indices_shape** - Address of the shape array of the index tensor.
- **indices_ndim** - Number of dimensions of the index tensor.
- **updates** - Address of the update data. Its shape must match the shape of the slices selected by the indices.
- **core_mask(int, optional)** - Core mask (shared-memory version only).

Outputs:

- **output** - Address of the updated result.

Supported platforms:

``FT78NE`` ``MT7004``

.. note::

   - FT78NE supports int8, int16, int32, fp32, fp64, cplx64 and cplx128.
   - MT7004 supports fp16, fp32, int16, int32 and cplx64.
   - The ``indices`` array is always of type int32.
   - The operator supports tensors with at most 8 dimensions.
   - The shared-memory version uses DMA transfers internally for acceleration and overwrites the slices directly in DDR.

**Shared-memory version:**

.. c:function:: void i8_scatter_nd_update_s(int8_t* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, int8_t* updates, int core_mask)

.. c:function:: void i16_scatter_nd_update_s(int16_t* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, int16_t* updates, int core_mask)

.. c:function:: void i32_scatter_nd_update_s(int* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, int* updates, int core_mask)

.. c:function:: void hp_scatter_nd_update_s(half* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, half* updates, int core_mask)

.. c:function:: void fp_scatter_nd_update_s(float* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, float* updates, int core_mask)

.. c:function:: void dp_scatter_nd_update_s(double* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, double* updates, int core_mask)

.. c:function:: void c64_scatter_nd_update_s(float* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, float* updates, int core_mask)

.. c:function:: void c128_scatter_nd_update_s(double* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, double* updates, int core_mask)

**C example:**

.. code-block:: c
   :linenos:
   :emphasize-lines: 17

   // FT78NE example (shared-memory version)
   #include "78NE/utils.h"

   int main() {
       float *output  = (float *)0xA0000000;  // original tensor, in DDR
       float *updates = (float *)0xB0000000;  // update values, in DDR
       int   *indices = (int *)0xC0000000;    // indices, in DDR

       int out_shape[] = {4, 4, 4};
       int ind_shape[] = {5, 2};  // 5 slices to update, index depth 2
       int out_ndim = 3;
       int ind_ndim = 2;
       int core_mask = 0xFF;      // run on 8 cores in parallel

       // indices holds 5 (i, j) pairs; each selects the 4-element slice output[i][j][:],
       // so updates must hold 5 x 4 = 20 floats.
       fp_scatter_nd_update_s(output, out_shape, out_ndim, indices, ind_shape, ind_ndim, updates, core_mask);

       return 0;
   }

**Private-memory version:**

.. c:function:: void i8_scatter_nd_update_p(int8_t* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, int8_t* updates)

.. c:function:: void i16_scatter_nd_update_p(int16_t* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, int16_t* updates)

.. c:function:: void i32_scatter_nd_update_p(int* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, int* updates)

.. c:function:: void hp_scatter_nd_update_p(half* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, half* updates)

.. c:function:: void fp_scatter_nd_update_p(float* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, float* updates)

.. c:function:: void dp_scatter_nd_update_p(double* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, double* updates)

.. c:function:: void c64_scatter_nd_update_p(float* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, float* updates)

.. c:function:: void c128_scatter_nd_update_p(double* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, double* updates)

**C example:**

.. code-block:: c
   :linenos:
   :emphasize-lines: 13

   // Include the library header that declares fp_scatter_nd_update_p
   // for your target platform.

   int main() {
       float *output  = (float *)0x10810000;  // buffers in private memory
       float *updates = (float *)0x10820000;
       int   *indices = (int *)0x10830000;

       int out_shape[] = {4, 4, 4};
       int ind_shape[] = {5, 2};

       // Call the single-core version
       fp_scatter_nd_update_p(output, out_shape, 3, indices, ind_shape, 2, updates);

       return 0;
   }
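To make the indexing rule concrete, the semantics can be checked on a host with a plain scalar loop. This is a minimal sketch, not the library's implementation: ``ref_scatter_nd_update`` is a hypothetical name, and the sketch assumes contiguous row-major tensors, in-range indices, and an ``updates`` layout of ``indices_shape[0..indices_ndim-2]`` followed by ``output_shape[K..output_ndim-1]`` (K being the index depth). That layout is the usual ScatterND convention and matches the shapes used in the examples above.

```c
/* Hypothetical host-side reference for ScatterNdUpdate (not the library API).
 * Assumes contiguous row-major tensors and in-range indices. */
void ref_scatter_nd_update(float *output, const int *output_shape, int output_ndim,
                           const int *indices, const int *indices_shape,
                           int indices_ndim, const float *updates)
{
    /* Index depth K: length of the last indices dimension. */
    int depth = indices_shape[indices_ndim - 1];

    /* Number of index rows: product of all but the last indices dimension. */
    long num_rows = 1;
    for (int d = 0; d < indices_ndim - 1; ++d)
        num_rows *= indices_shape[d];

    /* Each row selects a slice spanning output dimensions K..output_ndim-1. */
    long slice_size = 1;
    for (int d = depth; d < output_ndim; ++d)
        slice_size *= output_shape[d];

    for (long r = 0; r < num_rows; ++r) {
        const int *idx = indices + r * depth;

        /* Flatten the K leading coordinates into a row-major element offset. */
        long offset = 0;
        for (int d = 0; d < depth; ++d)
            offset = offset * output_shape[d] + idx[d];
        offset *= slice_size;

        /* Overwrite the selected slice with the matching block of updates. */
        for (long e = 0; e < slice_size; ++e)
            output[offset + e] = updates[r * slice_size + e];
    }
}
```

For the 4x4x4 output and 5x2 ``indices`` used in the examples above, the index row (1, 2) maps to the flat offset (1*4 + 2)*4 = 24, so ``output[24..27]`` receives ``updates[4..7]``. The sketch processes index rows in order, so if two rows repeat the same coordinates the later one wins; this page does not state how the library itself handles duplicate indices.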